{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Application Architectures Deep Dive"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computer Vision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### cnn_learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'cut': -2,\n",
       " 'split': <function fastai2.vision.learner._resnet_split(m)>,\n",
       " 'stats': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_meta[resnet50]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): AdaptiveConcatPool2d(\n",
       "    (ap): AdaptiveAvgPool2d(output_size=1)\n",
       "    (mp): AdaptiveMaxPool2d(output_size=1)\n",
       "  )\n",
       "  (1): full: False\n",
       "  (2): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Dropout(p=0.25, inplace=False)\n",
       "  (4): Linear(in_features=20, out_features=512, bias=False)\n",
       "  (5): ReLU(inplace=True)\n",
       "  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (7): Dropout(p=0.5, inplace=False)\n",
       "  (8): Linear(in_features=512, out_features=2, bias=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "create_head(20,2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### unet_learner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A Siamese Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from fastai2.vision.all import *\n",
    "path = untar_data(URLs.PETS)\n",
    "files = get_image_files(path/\"images\")\n",
    "\n",
    "class SiameseImage(Tuple):\n",
    "    def show(self, ctx=None, **kwargs): \n",
    "        img1,img2,same_breed = self\n",
    "        if not isinstance(img1, Tensor):\n",
    "            if img2.size != img1.size: img2 = img2.resize(img1.size)\n",
    "            t1,t2 = tensor(img1),tensor(img2)\n",
    "            t1,t2 = t1.permute(2,0,1),t2.permute(2,0,1)\n",
    "        else: t1,t2 = img1,img2\n",
    "        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)\n",
    "        return show_image(torch.cat([t1,line,t2], dim=2), \n",
    "                          title=same_breed, ctx=ctx)\n",
    "    \n",
    "def label_func(fname):\n",
    "    return re.match(r'^(.*)_\\d+.jpg$', fname.name).groups()[0]\n",
    "\n",
    "class SiameseTransform(Transform):\n",
    "    def __init__(self, files, label_func, splits):\n",
    "        self.labels = files.map(label_func).unique()\n",
    "        self.lbl2files = {l: L(f for f in files if label_func(f) == l) for l in self.labels}\n",
    "        self.label_func = label_func\n",
    "        self.valid = {f: self._draw(f) for f in files[splits[1]]}\n",
    "        \n",
    "    def encodes(self, f):\n",
    "        f2,t = self.valid.get(f, self._draw(f))\n",
    "        img1,img2 = PILImage.create(f),PILImage.create(f2)\n",
    "        return SiameseImage(img1, img2, t)\n",
    "    \n",
    "    def _draw(self, f):\n",
    "        same = random.random() < 0.5\n",
    "        cls = self.label_func(f)\n",
    "        if not same: cls = random.choice(L(l for l in self.labels if l != cls)) \n",
    "        return random.choice(self.lbl2files[cls]),same\n",
    "    \n",
    "splits = RandomSplitter()(files)\n",
    "tfm = SiameseTransform(files, label_func, splits)\n",
    "tls = TfmdLists(files, tfm, splits=splits)\n",
    "dls = tls.dataloaders(after_item=[Resize(224), ToTensor], \n",
    "    after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SiameseModel(Module):\n",
    "    def __init__(self, encoder, head):\n",
    "        self.encoder,self.head = encoder,head\n",
    "    \n",
    "    def forward(self, x1, x2):\n",
    "        ftrs = torch.cat([self.encoder(x1), self.encoder(x2)], dim=1)\n",
    "        return self.head(ftrs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = create_body(resnet34, cut=-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "head = create_head(512*4, 2, ps=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SiameseModel(encoder, head)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_func(out, targ):\n",
    "    return nn.CrossEntropyLoss()(out, targ.long())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def siamese_splitter(model):\n",
    "    return [params(model.encoder), params(model.head)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(dls, model, loss_func=loss_func, \n",
    "                splitter=siamese_splitter, metrics=accuracy)\n",
    "learn.freeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.367015</td>\n",
       "      <td>0.281242</td>\n",
       "      <td>0.885656</td>\n",
       "      <td>00:26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.307688</td>\n",
       "      <td>0.214721</td>\n",
       "      <td>0.915426</td>\n",
       "      <td>00:26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.275221</td>\n",
       "      <td>0.170615</td>\n",
       "      <td>0.936401</td>\n",
       "      <td>00:26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.223771</td>\n",
       "      <td>0.159633</td>\n",
       "      <td>0.943843</td>\n",
       "      <td>00:26</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(4, 3e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.212744</td>\n",
       "      <td>0.159033</td>\n",
       "      <td>0.944520</td>\n",
       "      <td>00:35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.201893</td>\n",
       "      <td>0.159615</td>\n",
       "      <td>0.942490</td>\n",
       "      <td>00:35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.204606</td>\n",
       "      <td>0.152338</td>\n",
       "      <td>0.945196</td>\n",
       "      <td>00:36</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.213203</td>\n",
       "      <td>0.148346</td>\n",
       "      <td>0.947903</td>\n",
       "      <td>00:36</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.unfreeze()\n",
    "learn.fit_one_cycle(4, slice(1e-6,1e-4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Natural Language Processing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tabular"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Wrapping Up Architectures"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Questionnaire"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. What is the \"head\" of a neural net?\n",
    "1. What is the \"body\" of a neural net?\n",
    "1. What is \"cutting\" a neural net? Why do we need to do this for transfer learning?\n",
    "1. What is `model_meta`? Try printing it to see what's inside.\n",
    "1. Read the source code for `create_head` and make sure you understand what each line does.\n",
    "1. Look at the output of `create_head` and make sure you understand why each layer is there, and how the `create_head` source created it.\n",
    "1. Figure out how to change the dropout, layer size, and number of layers created by `create_cnn`, and see if you can find values that result in better accuracy from the pet recognizer.\n",
    "1. What does `AdaptiveConcatPool2d` do?\n",
    "1. What is \"nearest neighbor interpolation\"? How can it be used to upsample convolutional activations?\n",
    "1. What is a \"transposed convolution\"? What is another name for it?\n",
    "1. Create a conv layer with `transpose=True` and apply it to an image. Check the output shape.\n",
    "1. Draw the U-Net architecture.\n",
    "1. What is \"BPTT for Text Classification\" (BPT3C)?\n",
    "1. How do we handle different length sequences in BPT3C?\n",
    "1. Try to run each line of `TabularModel.forward` separately, one line per cell, in a notebook, and look at the input and output shapes at each step.\n",
    "1. How is `self.layers` defined in `TabularModel`?\n",
    "1. What are the five steps for preventing over-fitting?\n",
    "1. Why don't we reduce architecture complexity before trying other approaches to preventing overfitting?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Further Research"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Write your own custom head and try training the pet recognizer with it. See if you can get a better result than fastai's default.\n",
    "1. Try switching between `AdaptiveConcatPool2d` and `AdaptiveAvgPool2d` in a CNN head and see what difference it makes.\n",
    "1. Write your own custom splitter to create a separate parameter group for every ResNet block, and a separate group for the stem. Try training with it, and see if it improves the pet recognizer.\n",
    "1. Read the online chapter about generative image models, and create your own colorizer, super-resolution model, or style transfer model.\n",
    "1. Create a custom head using nearest neighbor interpolation and use it to do segmentation on CamVid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
